import os

os.environ["CUDA_VISIBLE_DEVICES"] = '2,3'
from langchain.agents import create_sql_agent
from langchain.agents.agent_toolkits import SQLDatabaseToolkit
from langchain.agents.agent_types import AgentType
from langchain.sql_database import SQLDatabase

from langchain.llms import VLLM

model_dir = '/datasets/fengjiahao/nlp/TongyiFinance/Tongyi-Finance-14B'

llm = VLLM(
    model=model_dir,
    trust_remote_code=True,  # mandatory for hf models
    max_new_tokens=128,
    top_k=10,
    top_p=0.95,
    temperature=0.8,
    tensor_parallel_size=2,
)

db = SQLDatabase.from_uri(
    "sqlite:////datasets/fengjiahao/nlp/bs_challenge_financial_14b_dataset/dataset/博金杯比赛数据.db",

    sample_rows_in_table_info=2)
toolkit = SQLDatabaseToolkit(db=db, llm=llm)

agent_executor = create_sql_agent(
    llm=llm,
    toolkit=toolkit,
    verbose=True,
    handle_parsing_errors=True,
    agent_type=AgentType.ZERO_SHOT_REACT_DESCRIPTION, return_intermediate_steps=True
)

answer = agent_executor.run(
    "请帮我计算，在20210105，中信行业分类划分的一级行业为综合金融行业中，涨跌幅最大股票的股票代码是？涨跌幅是多少？百分数保留两位小数。股票涨跌幅定义为：（收盘价 - 前一日收盘价 / 前一日收盘价）* 100%。"
)
print(answer)
print()
